{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "## 自动生成聊天（第一句）的尝试\n",
        "\n",
        "这个脚本由李鲁鲁开发， 属于[Chat凉宫春日](https://github.com/LC1332/Chat-Haruhi-Suzumiya) \n",
        "\n",
        "用于研究是否能够基于GPT生成大量的对话数据\n",
        "\n",
        "**Chat凉宫春日**是模仿凉宫春日等一系列动漫人物，使用近似语气、个性和剧情聊天的语言模型，\n",
        "\n",
        "<details>\n",
        "  <summary> 由李鲁鲁，冷子昂，闫晨曦，封小洋等开发。 </summary>\n",
        "\n",
        "李鲁鲁发起了项目，并完成了最早的版本，在多个微信群实现了测试。\n",
        "\n",
        "冷子昂参与了早期Gradio的开发，并且参与了后端和前端的选型\n",
        "\n",
        "闫晨曦将李鲁鲁的notebook重构为app.py\n",
        "\n",
        "封小洋进行了中文转日文模型的选型\n",
        "\n",
        "</details>\n",
        "\n",
        "这个脚本主要是利用-1脚本已经抽取的关键词，两个jsonl文件\n",
        "\n",
        "来进行工作\n",
        "\n",
        "- [x] 组织句子生成的prompt\n",
        "- [x] 引入故事中的关键词\n",
        "- [x] 检验返回的jsonl是否合理\n",
        "- [x] 将生成结果 存储到google drive\n",
        "- [x] 合成一个大的jsonl\n"
      ],
      "metadata": {
        "id": "fqx4clSG3Z41"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "Eob70uqD3RyJ"
      },
      "outputs": [],
      "source": [
        "#@title 安装环境\n",
        "! pip -q install openai gradio transformers tiktoken langchain gradio"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 输入OpenAI Token"
      ],
      "metadata": {
        "id": "MJiJrcm65o6r"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "import openai\n",
        "\n",
        "openai.api_key = 'sk-lfrdoJKjlG' # 在这里输入你的OpenAI API Token\n",
        "os.environ[\"OPENAI_API_KEY\"] = openai.api_key "
      ],
      "metadata": {
        "id": "81basUOe5khn"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "<details>\n",
        "  <summary> 让GPT老师写一个jsonl读取 </summary>\n",
        "\n",
        "  实现一个python函数，输入一个文件名，是一个jsonl文件\n",
        "  每一行是一个json（会包含中文）\n",
        "  将这个jsonl解析到一个list并返回\n",
        "\n",
        "</details>"
      ],
      "metadata": {
        "id": "FsBXC2bm51ZJ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@title 实现parse_jsonl_file函数\n",
        "\n",
        "import json\n",
        "\n",
        "def parse_jsonl_file(file_name):\n",
        "    json_list = []\n",
        "    with open(file_name, 'r', encoding='utf-8') as f:\n",
        "        for line in f:\n",
        "            json_obj = json.loads(line.strip())\n",
        "            json_list.append(json_obj)\n",
        "    return json_list"
      ],
      "metadata": {
        "cellView": "form",
        "id": "AccAqMsz5wkZ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "fname_story_data = '/content/all_story_keywords.jsonl'\n",
        "\n",
        "story_keywords = parse_jsonl_file(fname_story_data )\n",
        "\n",
        "fname_chat_data = '/content/all_chat_datas.jsonl'\n",
        "\n",
        "chat_datas = parse_jsonl_file(fname_chat_data)"
      ],
      "metadata": {
        "id": "SCFFa_hR6Qlk"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title 统一story和chat中的keywords字段（前者本来为entity）\n",
        "\n",
        "def replace_entity_with_keywords(story_keywords):\n",
        "    for item in story_keywords:\n",
        "        if \"Entity\" in item:\n",
        "            item[\"keywords\"] = item.pop(\"Entity\")\n",
        "    return story_keywords\n",
        "\n",
        "story_keywords = replace_entity_with_keywords(story_keywords)"
      ],
      "metadata": {
        "cellView": "form",
        "id": "_1wbipGi7JNb"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title 去除停词关键词\n",
        "\n",
        "stop_words = ['春日','阿虚','凉宫','凉宫春日']\n",
        "\n",
        "\n",
        "def remove_stop_words(data, stop_words):\n",
        "    stop_words_set = set(stop_words) # 转换为set方便查找\n",
        "    for item in data:\n",
        "        # if \"keywords\" in item: # 判断关键词列表是否存在\n",
        "        item[\"keywords\"] = [w for w in item[\"keywords\"] if w not in stop_words_set]\n",
        "    return data\n",
        "\n",
        "chat_datas = remove_stop_words(chat_datas, stop_words)\n",
        "story_keywords = remove_stop_words(story_keywords, stop_words)\n"
      ],
      "metadata": {
        "id": "GhIgJ60T7Tog"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "print(len(story_keywords))\n",
        "print(len(chat_datas))\n",
        "\n",
        "print(story_keywords[0])\n",
        "print(chat_datas[0])"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "9IBYs6Br6mrS",
        "outputId": "50bc23db-832d-4715-e2bb-0a2a2bbf26b5"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "48\n",
            "556\n",
            "{'keywords': ['电脑', '资讯化的时代', '拍立得', '作战计划', '把握机会']}\n",
            "{'role_A': '阿虚', 'role_B': '春日', 'query': '「今天在计算机课上老师教了我写Python!」', 'response': '「哦？Python？那你能不能帮我写一个程序啊？」', 'keywords': ['计算机课', 'Python']}\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 让我们来测试Prompt！"
      ],
      "metadata": {
        "id": "WMSwoHZI8S8Q"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@title 定义organize_samples 和 list_to_string\n",
        "\n",
        "from typing import List, Dict, Tuple\n",
        "import numpy as np\n",
        "\n",
        "def list_to_string(lst):\n",
        "    result = ''\n",
        "    for item in lst:\n",
        "        result += str(item) + '\\n'\n",
        "    return result\n",
        "\n",
        "def organize_samples(sel_chat_datas: List[Dict[str, str]]) -> Tuple[List[Dict], List[Dict], List[str]]:\n",
        "    # stop_words = ['春日', '阿虚', '凉宫', '凉宫春日']\n",
        "    sample_input = []\n",
        "    sample_output = []\n",
        "    all_keywords = set()\n",
        "    for element in sel_chat_datas:\n",
        "        keywords = element['keywords']  # [kw for kw in element['keywords'] if kw not in stop_words]\n",
        "        np.random.shuffle(keywords)\n",
        "        sample_input.append({'keywords': keywords})\n",
        "        output_element = {\n",
        "            'keywords': keywords,\n",
        "            'role': element['role_A'],\n",
        "            'text': element['query'],\n",
        "        }\n",
        "        sample_output.append(output_element)\n",
        "        for kw in keywords:\n",
        "            all_keywords.add(kw)\n",
        "    return sample_input, sample_output, list(all_keywords)"
      ],
      "metadata": {
        "id": "g_SvKF5G9J3n"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title 测试organize_samples\n",
        "\n",
        "sel_chat_data = [chat_datas[i] for i in range(3)]\n",
        "\n",
        "sample_input, sample_output, _ = organize_samples(sel_chat_data)\n",
        "\n",
        "print(list_to_string(sample_input))\n",
        "# print('\\n')\n",
        "print(list_to_string(sample_output))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "7JH9Q57y9mYP",
        "outputId": "1585ef16-d168-4d87-fb8d-39631d3f703c"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "{'keywords': ['Python', '计算机课']}\n",
            "{'keywords': ['什么样的程序', '写程序']}\n",
            "{'keywords': ['程序', '赚很多钱', '预测彩票']}\n",
            "\n",
            "{'keywords': ['Python', '计算机课'], 'role': '阿虚', 'text': '「今天在计算机课上老师教了我写Python!」'}\n",
            "{'keywords': ['什么样的程序', '写程序'], 'role': '阿虚', 'text': '「你想写一个什么样的程序呢？」'}\n",
            "{'keywords': ['程序', '赚很多钱', '预测彩票'], 'role': '阿虚', 'text': '「如果有一个能预测彩票的程序，我们岂不是能赚很多钱？」'}\n",
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 后面的目标\n",
        "\n",
        "我写到这里觉得有点头秃\n",
        "\n",
        "我们要组一个one-shot或者说two-shot的prompt\n",
        "\n",
        "这里关键是组织新的input很令人头痛\n",
        "\n",
        "因为我希望新的input是这样的，keywords不出现任意原来sample input的列表中，\n",
        "\n",
        "这样太耦合了，我们先做一个foo_input，然后把prompt组装的函数给写了\n",
        "\n"
      ],
      "metadata": {
        "id": "i6kQfVrhAEB9"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@title 一个之后会废弃的foo_sample和foo_input函数\n",
        "\n",
        "import random \n",
        "\n",
        "n = len(chat_datas)\n",
        "sel_all = random.sample(range(0, n), 20)\n",
        "\n",
        "sel_sample = sel_all[:10]\n",
        "sel_input = sel_all[10:]\n",
        "\n",
        "def foo_sample():\n",
        "    sel_chat_data = [chat_datas[i] for i in sel_sample]\n",
        "\n",
        "    sample_input, sample_output, sample_keywords = organize_samples(sel_chat_data)\n",
        "\n",
        "    return sample_input, sample_output, sample_keywords\n",
        "    \n",
        "\n",
        "def foo_input(sample_keywords):\n",
        "    sel_chat_data = [chat_datas[i] for i in sel_input]\n",
        "\n",
        "    sample_input, _ , _ = organize_samples(sel_chat_data)\n",
        "\n",
        "    return sample_input"
      ],
      "metadata": {
        "id": "ftiVpPapBsPL",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "TANSFER_PROMPT = \"\"\"\n",
        "根据keywords的内容补全text\n",
        "text为对于凉宫春日剧情的一些讨论问题，role不可以是春日或者凉宫春日\n",
        "role可以是阿虚、朝比奈、老师等凉宫春日中，非春日的其他角色\n",
        "role也可以是任意其他动漫中的角色\n",
        "用一致性的语言风格，根据每行中的json内容，根据keywords中的关键字，补全text的内容。\n",
        "\"\"\""
      ],
      "metadata": {
        "id": "rWFqz2fPGfME"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title 定义生成prompt\n",
        "\n",
        "from langchain.chat_models import ChatOpenAI\n",
        "\n",
        "from langchain.prompts.chat import (\n",
        "    ChatPromptTemplate,\n",
        "    SystemMessagePromptTemplate,\n",
        "    AIMessagePromptTemplate,\n",
        "    HumanMessagePromptTemplate,\n",
        ")\n",
        "from langchain.schema import (\n",
        "    AIMessage,\n",
        "    HumanMessage,\n",
        "    SystemMessage\n",
        ")\n",
        "\n",
        "chatModel = ChatOpenAI(temperature=0.1, max_tokens = 2000)\n",
        "\n",
        "sample_input, sample_output, sample_keywords = foo_sample()\n",
        "query_input = foo_input(sample_keywords)\n",
        "\n",
        "def generate_with_keywords( sample_input, sample_output, query_input ):\n",
        "\n",
        "    div_k = 4\n",
        "    input1 = list_to_string( sample_input[:div_k] )\n",
        "    output1 = list_to_string( sample_output[:div_k] )\n",
        "    input2 = list_to_string( sample_input[div_k:] )\n",
        "    output2 = list_to_string( sample_output[div_k:] )\n",
        "\n",
        "    query = list_to_string(query_input)\n",
        "\n",
        "    messages = [\n",
        "        SystemMessage(content=TANSFER_PROMPT),\n",
        "        HumanMessage(content=input1),\n",
        "        AIMessage(content=output1),\n",
        "        HumanMessage(content=input2),\n",
        "        AIMessage(content=output2),\n",
        "        HumanMessage(content=query)\n",
        "    ]\n",
        "\n",
        "    return_msg = chatModel(messages)\n",
        "\n",
        "    return return_msg.content"
      ],
      "metadata": {
        "id": "pDHsUh8m8dOc"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "response = generate_with_keywords( sample_input, sample_output, query_input )\n",
        "print(response)"
      ],
      "metadata": {
        "id": "-mvP4M0Zz5Hm"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "接下来我们要开始真正组织的input\n",
        "\n",
        "<details>\n",
        "  <summary> 让GPT老师开发一下数据类 </summary>\n",
        "\n",
        "我希望实现一个python的类 DataLoader。这个python类由一个list初始化。并且内部会记录list的个数n\n",
        "\n",
        "初始化的时候，会自动建立一个长度为n的shuffle_id,用来从整个list随机取元素，并把currend_id赋值为0\n",
        "\n",
        "这个类有一个方法getData()， 每次会根据current_id依次返回一个数据（序号为shuffle_id[current_id]的数据）\n",
        "\n",
        "当返回了n个数据的时候，DataLoader会重新shuffle数据的顺序。\n",
        "\n",
        "并且可以额外设定一个数据k(默认值为10)，当n > 2k时，每次重新shuffle会保证新的shuffe_id序列的前k-1个元素和上一次shuffe_id的后k-1个元素之间没有重复元素。这样保证在连续取k个数据时，总是不重复\n",
        "\n",
        "\n",
        "</details>"
      ],
      "metadata": {
        "id": "DFYfZrMXoPrX"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@title 建立DataLoader类\n",
        "\n",
        "import random\n",
        "\n",
        "class DataLoader:\n",
        "    def __init__(self, data, k=10):\n",
        "        self.data = data\n",
        "        self.n = len(data)\n",
        "        self.k = k\n",
        "        self.current_id = 0\n",
        "        self.shuffle_id = list(range(self.n))\n",
        "        random.shuffle(self.shuffle_id)\n",
        "        self.previous_tail = self.shuffle_id[-self.k+1:]\n",
        "\n",
        "    def shuffle(self):\n",
        "        if self.n <= 2 * self.k:\n",
        "            random.shuffle(self.shuffle_id)\n",
        "        else:\n",
        "            random.shuffle(self.shuffle_id)\n",
        "            head = self.shuffle_id[:self.k-1]\n",
        "            flag = True\n",
        "            count = 0\n",
        "\n",
        "            min_ovlp_num = 999\n",
        "            min_ovlp_plan = []\n",
        "\n",
        "            while count < 10 and flag == True:\n",
        "                count = count + 1\n",
        "                inverse_flag = False\n",
        "                ovlp_num = 0\n",
        "                for id in head:\n",
        "                    if id in self.previous_tail:\n",
        "                        ovlp_num = ovlp_num + 1\n",
        "                        inverse_flag = True\n",
        "\n",
        "                if ovlp_num < min_ovlp_num:\n",
        "                    min_ovlp_num = ovlp_num\n",
        "                    min_ovlp_plan = self.shuffle_id.copy()\n",
        "\n",
        "                if False == inverse_flag:\n",
        "                    flag = False\n",
        "                    break\n",
        "\n",
        "                random.shuffle(self.shuffle_id)\n",
        "                head = self.shuffle_id[:self.k-1]\n",
        "\n",
        "            # print('shuffle test time ', count, ' min ovlp = ', min_ovlp_num)\n",
        "\n",
        "            if min_ovlp_num > 0:\n",
        "                self.shuffle_id = min_ovlp_plan\n",
        "\n",
        "            head = self.shuffle_id[self.k-1:]\n",
        "            tail = self.shuffle_id[-self.k+1:]\n",
        "\n",
        "            self.shuffle_id = head + self.shuffle_id[:self.k-1] + tail\n",
        "            random.shuffle(self.shuffle_id)\n",
        "            self.previous_tail = tail\n",
        "\n",
        "    def get_data(self):\n",
        "        if self.current_id >= self.n:\n",
        "            self.shuffle()\n",
        "            self.current_id = 0\n",
        "        data = self.data[self.shuffle_id[self.current_id]]\n",
        "        self.current_id += 1\n",
        "        return data"
      ],
      "metadata": {
        "cellView": "form",
        "id": "TxY8jlSloOOB"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "data_story = DataLoader(story_keywords, 10)\n",
        "\n",
        "data_chat_as_story = DataLoader(chat_datas, 10)\n",
        "\n",
        "data_chat = DataLoader(chat_datas, 10)\n",
        "\n",
        "for i in range(900):\n",
        "    data = data_story.get_data()\n",
        "    if i % 300 == 0:\n",
        "        print(data)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "TJ6vHaKYpy6j",
        "outputId": "dc43c889-9f9b-4201-9281-0eb7ad4acc3f"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "{'keywords': ['侦探推理小说迷', '推理研究会']}\n",
            "{'keywords': ['长门有希', '社团教室']}\n",
            "{'keywords': ['名侦探', '小众文化类同好会']}\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "save_path = \"/content/drive/MyDrive/GPTData/Haruhi-AutoFirst/\""
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "SI5SlOBvy8Xs",
        "outputId": "dee47773-d11a-4865-fae3-30794fd71559"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mounted at /content/drive\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import datetime\n",
        "import os\n",
        "\n",
        "# save_path = \"/content/drive/MyDrive/GPTData/Haruhi-AutoFirst/\"\n",
        "\n",
        "def save_to_file(response, save_path):\n",
        "    # 获取当前时间并转换成字符串作为文件名\n",
        "    now = datetime.datetime.now()\n",
        "    timestamp_str = now.strftime(\"%Y-%m-%d-%H-%M-%S\")\n",
        "\n",
        "    # 拼接完整的文件路径\n",
        "    file_path = os.path.join(save_path, f\"{timestamp_str}.txt\")\n",
        "\n",
        "    # 如果文件已经存在，则在文件名尾部加上一个随机字符串\n",
        "    while os.path.exists(file_path):\n",
        "        random_suffix = ''.join(random.choices(string.ascii_lowercase + string.digits, k=4))\n",
        "        file_path = os.path.join(save_path, f\"{timestamp_str}-{random_suffix}.txt\")\n",
        "\n",
        "    # 将response写入文件\n",
        "    with open(file_path, \"w\", encoding=\"utf-8\") as f:\n",
        "        f.write(response)"
      ],
      "metadata": {
        "id": "YvzgGADkzk-b"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from tqdm import tqdm\n",
        "\n",
        "batch_size = 10\n",
        "\n",
        "for iter_time in tqdm(range(700),desc='autoGenerating'):\n",
        "\n",
        "    chat_data = []\n",
        "\n",
        "    for _ in range(batch_size):\n",
        "        chat_data.append( data_chat.get_data() )\n",
        "\n",
        "    sample_input, sample_output, sample_keywords = organize_samples(chat_data)\n",
        "\n",
        "    #这里我们还要组织query_input\n",
        "\n",
        "    query_input = []\n",
        "\n",
        "    for input in sample_input:\n",
        "        target_n = len( input['keywords'] )\n",
        "        target_n = max(2, target_n )\n",
        "\n",
        "        count_time = 0\n",
        "        max_len = -999\n",
        "        max_len_plan = []\n",
        "\n",
        "        while count_time < 15:\n",
        "            #随机抽取一个story的keyword\n",
        "            count_time = count_time + 1\n",
        "            if iter_time % 2 == 0:\n",
        "                story_keyword = data_story.get_data()\n",
        "            else:\n",
        "                story_keyword = data_chat_as_story.get_data()\n",
        "\n",
        "            filtered_keyword = [w for w in story_keyword[\"keywords\"] if w not in sample_keywords]\n",
        "            if len(filtered_keyword) >= target_n:\n",
        "                story_keyword['keywords'] = random.sample(filtered_keyword, min(target_n, len(filtered_keyword)))\n",
        "                break\n",
        "            else:\n",
        "                if len(filtered_keyword)>max_len:\n",
        "                    max_len = len(filtered_keyword)\n",
        "                    # story_keyword['keywords'] = filtered_keyword\n",
        "                    max_len_plan = filtered_keyword.copy()\n",
        "\n",
        "        if len(story_keyword['keywords'] ) < target_n:\n",
        "            story_keyword['keywords'] = max_len_plan\n",
        "            # print('use max len plan ', target_n - len(story_keyword['keywords'] ))\n",
        "        query_input.append( {'keywords':story_keyword['keywords']} )\n",
        "\n",
        "        for keyword in story_keyword['keywords']:\n",
        "            sample_keywords.append(keyword)\n",
        "\n",
        "    # response = generate_with_keywords( sample_input, sample_output, query_input )\n",
        "    try:\n",
        "        response = generate_with_keywords(sample_input, sample_output, query_input)\n",
        "    except Exception as e:\n",
        "        print(f\"An error occurred while running the script: {e}\")\n",
        "        break\n",
        "\n",
        "    save_to_file(response,save_path)\n",
        "\n",
        "    # if iter_time > 5:\n",
        "    #     break"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "UeGPyAIOtiv2",
        "outputId": "c437886e-9598-444f-ba80-cf8abd2b30a2"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "autoGenerating:   4%|▍         | 29/700 [21:53<8:24:27, 45.11s/it]WARNING:langchain.chat_models.openai:Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 1.0 seconds as it raised RateLimitError: That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 6acba8e1f9e85d735d9c04ba0c04a555 in your message.).\n",
            "autoGenerating:  14%|█▎        | 95/700 [1:09:50<7:43:22, 45.96s/it]WARNING:langchain.chat_models.openai:Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 1.0 seconds as it raised RateLimitError: That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID f8fc9d94d7bcdaa1e01ec134376936e5 in your message.).\n",
            "autoGenerating:  34%|███▍      | 241/700 [3:01:05<5:50:26, 45.81s/it]WARNING:langchain.chat_models.openai:Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 1.0 seconds as it raised RateLimitError: That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 74d913082bd0fb0a62e365aa42c64297 in your message.).\n",
            "autoGenerating:  39%|███▉      | 275/700 [3:27:06<5:25:17, 45.92s/it]WARNING:langchain.chat_models.openai:Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 1.0 seconds as it raised RateLimitError: That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID c752d75aa5e8fb259769eba9814a7b34 in your message.).\n",
            "autoGenerating:  43%|████▎     | 302/700 [3:48:44<5:00:56, 45.37s/it]WARNING:langchain.chat_models.openai:Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 1.0 seconds as it raised RateLimitError: That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 4e2797304869263532ef9f69d37e0050 in your message.).\n",
            "autoGenerating:  53%|█████▎    | 374/700 [4:43:01<4:03:38, 44.84s/it]WARNING:langchain.chat_models.openai:Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 1.0 seconds as it raised RateLimitError: That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 9df1f7ebf763135286670f393f914d05 in your message.).\n",
            "autoGenerating:  64%|██████▎   | 446/700 [5:35:23<3:01:28, 42.87s/it]WARNING:langchain.chat_models.openai:Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 1.0 seconds as it raised RateLimitError: That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID f4e79d5876527575bfdc590d19953205 in your message.).\n",
            "autoGenerating:  75%|███████▌  | 526/700 [6:38:11<2:29:07, 51.42s/it]WARNING:langchain.chat_models.openai:Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 1.0 seconds as it raised RateLimitError: That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID bfef80138956765afcc703741e466766 in your message.).\n",
            "autoGenerating:  76%|███████▌  | 530/700 [6:42:10<2:35:58, 55.05s/it]WARNING:langchain.chat_models.openai:Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 1.0 seconds as it raised RateLimitError: That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 4e794d4dbe9311595c1b997b18d64972 in your message.).\n",
            "autoGenerating:  77%|███████▋  | 540/700 [6:50:49<2:13:32, 50.08s/it]WARNING:langchain.chat_models.openai:Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 1.0 seconds as it raised RateLimitError: That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 713b193f20fabf0dfba7d20b8cef8116 in your message.).\n",
            "autoGenerating: 100%|██████████| 700/700 [8:57:45<00:00, 46.09s/it]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "看一下长度"
      ],
      "metadata": {
        "id": "wqc1C96PHbja"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "合成成一个jsonl文件\n",
        "\n",
        "再召唤GPT老师\n",
        "\n",
        "已知在python中\n",
        "\n",
        "save_path = \"/content/drive/MyDrive/GPTData/Haruhi-AutoFirst/\"\n",
        "\n",
        "save_name = \"/content/drive/MyDrive/GPTData/Haruhi_first_merge.jsonl\"\n",
        "\n",
        "save_path中存储了很多jsonl格式的文件，后缀名是.txt，每行是一个json，其中包含中文。\n",
        "\n",
        "对save_path中所有的.txt文件进行读取，逐行解析，\n",
        "\n",
        "如果解析成功，则append到一个list中\n",
        "如果解析失败，则继续解析下一个文件\n",
        "\n",
        "最后将list，以jsonl格式存储到save_name中"
      ],
      "metadata": {
        "id": "jSn4O95zzmrW"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import json\n",
        "\n",
        "save_path = \"/content/drive/MyDrive/GPTData/Haruhi-AutoFirst/\"\n",
        "save_name = \"/content/drive/MyDrive/GPTData/Haruhi_first_merge.jsonl\"\n",
        "\n",
        "# 定义一个空列表，用于存储解析成功的json数据\n",
        "json_list = []\n",
        "\n",
        "# 遍历save_path目录下的所有txt文件\n",
        "for file_name in os.listdir(save_path):\n",
        "    if file_name.endswith(\".txt\"):\n",
        "        file_path = os.path.join(save_path, file_name)\n",
        "        with open(file_path, \"r\", encoding=\"utf-8\") as f:\n",
        "            # 逐行读取文件内容\n",
        "            for line in f:\n",
        "                # 尝试解析读取到的行数据\n",
        "                try:\n",
        "                    my_str = line.strip()\n",
        "                    my_str = my_str.replace(\"'\", \"\\\"\")\n",
        "                    json_data = json.loads(my_str)\n",
        "                    json_list.append(json_data)\n",
        "                except:\n",
        "                    my_str = line.strip()\n",
        "                    print('warining in line ', my_str)\n",
        "                    continue\n",
        "\n",
        "# 将解析成功的json数据以jsonl格式写入save_name文件中\n",
        "with open(save_name, \"w\", encoding=\"utf-8\") as f:\n",
        "    for json_data in json_list:\n",
        "        f.write(json.dumps(json_data, ensure_ascii=False) + \"\\n\")"
      ],
      "metadata": {
        "id": "GUZR1UobH_wY"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "print(len(json_list))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "hN94f64d0h4g",
        "outputId": "eee578b7-b551-4f00-f5dd-0ca1a5f696a5"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "7079\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "print(chat_datas[0])\n",
        "print(chat_datas[1])"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "EpJwcGds0zIm",
        "outputId": "bffc195e-8e12-4248-dd76-6333ae7b5bba"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "{'role_A': '阿虚', 'role_B': '春日', 'query': '「今天在计算机课上老师教了我写Python!」', 'response': '「哦？Python？那你能不能帮我写一个程序啊？」', 'keywords': ['Python', '计算机课']}\n",
            "{'role_A': '阿虚', 'role_B': '春日', 'query': '「你想写一个什么样的程序呢？」', 'response': '「我想写一个能够预测未来的程序，可以预测天气、地震、彩票号码等等。」', 'keywords': ['什么样的程序', '写程序']}\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "额外再做一个事情吧。把chat_datas做成一个更好的形式。"
      ],
      "metadata": {
        "id": "PNb4QCfj-PN7"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "with open('/content/temp_all_chat.jsonl','w', encoding=\"utf-8\") as f:\n",
        "    for chat in chat_datas:\n",
        "        temp_json1 = {'role':chat['role_A'],'text':chat['query']}\n",
        "        temp_json2 = {'role':chat['role_B'],'text':chat['response']}\n",
        "        f.write(json.dumps(temp_json1, ensure_ascii=False) + \"\\n\")\n",
        "        f.write(json.dumps(temp_json2, ensure_ascii=False) + \"\\n\")"
      ],
      "metadata": {
        "id": "NrN6tNj0-LTP"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "hLTLHmLw-uq_"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}